
Set shape basis method #276

Merged
merged 47 commits into development from set_shape_basis_method
Dec 18, 2024

Conversation


@BalzaniEdoardo BalzaniEdoardo commented Dec 5, 2024

Add a set_input_shape method that initializes the basis.

The method can accept a list of:

  • Integers: assumes the input is flat (if n = 1) or 2D (if n > 1). In the first case the basis accepts inputs of shape (n_samples,), in the second (n_samples, n).
  • Tuples of integers: the tuple gives the input shape with the first (sample) axis trimmed. If the tuple is (n, m, ...), the basis expects inputs of shape (n_samples, n, m, ...).
  • Arrays: the actual inputs that will later be passed to compute_features; their shapes, minus the sample axis, are stored.
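As a rough sketch (not the actual nemos implementation; the helper name is hypothetical), the three accepted specifications can all be normalized to the expected trailing shape, i.e. the input shape after the sample axis is dropped:

```python
import numpy as np

def normalize_input_shape(spec):
    """Illustrative helper: convert an int, tuple, or example array
    into the expected input shape with the sample axis dropped."""
    if isinstance(spec, int):
        # n == 1 -> flat input (n_samples,); n > 1 -> (n_samples, n)
        return () if spec == 1 else (spec,)
    if isinstance(spec, tuple):
        # the tuple is already the shape with the sample axis trimmed
        return spec
    # an example array: drop its first (sample) axis
    return np.asarray(spec).shape[1:]
```

All three call styles then produce the same kind of stored state, which is what lets the basis later validate inputs and compute feature counts.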

It prepares the basis and precomputes internal quantities needed by subsequent functionality, such as how to split the feature axis (i.e., split_by_features). In the next PR, TransformerBasis will use the precomputed shapes to split the concatenated inputs before processing.
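To see why the stored shapes matter, here is a minimal sketch (hypothetical helper name, not nemos code) of splitting a concatenated feature axis back into per-basis blocks, which is the operation the precomputed sizes enable:

```python
import numpy as np

def split_concatenated_features(X, n_features_per_basis):
    """Split the feature axis of X, shape (n_samples, total_features),
    into one block per basis using the precomputed per-basis sizes."""
    boundaries = np.cumsum(n_features_per_basis)[:-1]
    return np.split(X, boundaries, axis=1)
```

Without the stored shapes, a downstream transformer would have no way to know where one basis's features end and the next begin.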

Additionally, it computes and stores the number of output features, information a user may want after building a complex composite basis.
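For intuition, the output-feature count of an atomic basis scales with the non-sample input dimensions. A sketch under that assumption (the rule and function name are illustrative, not taken from the nemos source):

```python
import math

def n_output_features(n_basis_funcs, input_shape):
    """Assumed rule: each basis function is applied to every element of
    the trailing input shape, so the feature count is n_basis_funcs
    times the product of those dimensions (1 for a flat input)."""
    return n_basis_funcs * math.prod(input_shape)
```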

This PR follows #275

EDIT:

PR Summary

In this PR, I refined the class structure to better separate basis attributes and methods, delegating their validation logic to mixin classes wherever possible.

New Classes

  • CompositeBasisMixin:
    This mixin is inherited by additive and multiplicative bases. It implements methods for traversing the composite basis tree, such as __sklearn_clone__ and setup_basis.

  • AtomicBasisMixin:
    Designed for non-composite (atomic) bases, this mixin stores the n_basis_funcs parameter and implements selected methods like __sklearn_clone__, which have uniform implementations across all atomic bases.

New Abstract Methods

  • set_input_shape:
    This method stores the expected input shape, a state attribute required for compatibility with transformers. Parameters set by this method are carried over during cloning, such as in cross-validation. Concrete implementations are provided in AtomicBasisMixin and CompositeBasisMixin.

  • _set_input_independent_states:
    Sets all state variables that depend only on class parameters (provided at initialization and retrievable with get_params), not on the input.

  • setup_basis:
    Computes all state variables, both input-dependent (e.g., input shape) and input-independent (e.g., kernels for convolutional bases). Concrete implementations are found in Eval, Conv, and Composite basis mixins.

Clone Method for Bases

  • __sklearn_clone__:
    Implements cloning logic to retain input-dependent states (such as input shape), which would otherwise be discarded by sklearn.base.clone. This is implemented in both the CompositeBasisMixin and AtomicBasisMixin.
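A minimal sketch of this cloning idea, using a toy class in place of a real basis (the state attribute name `_input_shape_` is an assumption for illustration):

```python
import copy

class ToyBasis:
    """Stand-in for an atomic basis with one init parameter."""

    def __init__(self, n_basis_funcs=5):
        self.n_basis_funcs = n_basis_funcs

    def get_params(self):
        return {"n_basis_funcs": self.n_basis_funcs}

    def __sklearn_clone__(self, state_attrs=("_input_shape_",)):
        # re-create from parameters, as sklearn.base.clone would,
        # then copy back the input-dependent state it would discard
        clone = type(self)(**self.get_params())
        for attr in state_attrs:
            if hasattr(self, attr):
                setattr(clone, attr, copy.deepcopy(getattr(self, attr)))
        return clone
```

This is what keeps shapes set via set_input_shape alive across the repeated cloning that happens during cross-validation.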

Modified/Moved Attributes and Methods

  • Kernel-related logic:
    The kernel_ attribute, along with the _check_has_kernel method and set_kernel, has been moved to the Conv mixin.

  • Input shape validation:
    The _check_input_shape_consistency method has been relocated to AtomicBasisMixin and CompositeBasisMixin.

  • Composite basis setters:
    Setters for basis1 and basis2 in composite bases now support cross-validation scenarios.

Inheritance of New Mixins

  • CompositeBasisMixin:
    Inherited by AdditiveBasis and MultiplicativeBasis.

  • AtomicBasisMixin:
    Inherited by the following classes:

    • SplineBasis (the superclass for all splines)
    • RaisedCosineLinear (the superclass for all raised cosines)
    • OrthExponentialBasis (the superclass for orthogonal exponential bases)


codecov-commenter commented Dec 5, 2024

Codecov Report

Attention: Patch coverage is 96.59574% with 8 lines in your changes missing coverage. Please review.

Project coverage is 96.13%. Comparing base (3691356) to head (a510ef3).
Report is 2 commits behind head on development.

Files with missing lines Patch % Lines
src/nemos/basis/_transformer_basis.py 78.12% 7 Missing ⚠️
src/nemos/basis/_basis_mixin.py 98.95% 1 Missing ⚠️
Additional details and impacted files
@@              Coverage Diff              @@
##           development     #276    +/-   ##
=============================================
  Coverage        96.13%   96.13%            
=============================================
  Files               34       34            
  Lines             2507     2642   +135     
=============================================
+ Hits              2410     2540   +130     
- Misses              97      102     +5     


@BalzaniEdoardo BalzaniEdoardo mentioned this pull request Dec 15, 2024

BalzaniEdoardo commented Dec 16, 2024

In the next PR, I'll improve the API of the transformer and move the TransformerBasis tests into a dedicated script.

If you want an overview of how to work with the TransformerBasis before digging into the code, check out this guide:

https://nemos--276.org.readthedocs.build/en/276/how_to_guide/plot_05_transformer_basis.html

@sjvenditto sjvenditto (Collaborator) left a comment
Looking good! I just have a couple of suggested changes to clarify, in the documentation and error messages, where set_input_shape will have an impact. Also, some code in a Warning admonition in plot_06_sklearn_pipeline_cv_demo needs to be updated to the new syntax.

Resolved review threads: src/nemos/basis/_basis.py, src/nemos/basis/_basis_mixin.py
@billbrod billbrod (Member) left a comment
Haven't had a chance to look at the tests, but here's my first pass. Some tweaking in the tutorials, and also:

  • I think we can probably move away from using numbers in the names of the tutorials. The only reason to do so is for the automatic sorting to work, but we're no longer making use of automatic sorting, right?
  • I am confused about what lives in the Basis superclass and what lives in AtomicBasisMixin. For example, why does anything to do with n_basis_input or n_basis_funcs live in Basis?

Resolved review threads: docs/how_to_guide/plot_05_transformer_basis.md (×5)
from ._basis import Basis


def set_input_shape_state(method):

why not make this more general and accept a list of attributes to copy, which default to ["_n_basis_input_", "_input_shape_"]?
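The reviewer's suggestion could look roughly like this: a decorator factory that takes the list of state attributes to carry over (a sketch of the idea, not the merged code):

```python
import functools

def set_input_shape_state(attrs=("_n_basis_input_", "_input_shape_")):
    """Decorator factory: after the wrapped method returns a new
    object, copy the listed state attributes from self onto it."""
    def decorator(method):
        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            out = method(self, *args, **kwargs)
            for attr in attrs:
                if hasattr(self, attr):
                    setattr(out, attr, getattr(self, attr))
            return out
        return wrapper
    return decorator
```

Callers that need a different set of attributes can then pass their own tuple instead of relying on the two defaults.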

Review threads: src/nemos/basis/_basis_mixin.py (×2)
Comment on lines 584 to 590
@property
def basis1(self):
return self._basis1

@basis1.setter
def basis1(self, bas: Basis):
self._basis1 = bas
why are these properties if we do nothing special in the getter or setter? why not just use self.basis1 instead of self._basis1?

Resolved review thread: src/nemos/basis/basis.py
@BalzaniEdoardo BalzaniEdoardo merged commit e8a62e8 into development Dec 18, 2024
13 checks passed
@BalzaniEdoardo BalzaniEdoardo deleted the set_shape_basis_method branch December 18, 2024 22:43
4 participants